function G = quadGradient(A, B)
%QUADGRADIENT Derivative of the quadratic form A*B*A.' with respect to A.
%   G = QUADGRADIENT(A, B) returns a 4-D array with
%
%       G(:, :, i, j) = d(A*B*A.') / dA(i, j)
%
%   for an m-by-n matrix A and an n-by-n matrix B, so size(G) is
%   [m, m, m, n] (MATLAB drops the trailing singleton when n == 1).
%
%   By the product rule, with E_ij the m-by-n unit matrix that has a one
%   at (i, j):
%
%       d(A*B*A.')/dA(i,j) = E_ij*B*A.' + A*B*E_ij.'
%
%   The first term is zero except for row i, which equals row j of B*A.'.
%   The second term is zero except for column i, which equals column j of
%   A*B. The plain transpose (.') is used throughout, so complex inputs
%   are not conjugated.
%
%   An error with identifier 'quadGradient:sizeMismatch' is raised when B
%   is not size(A, 2)-by-size(A, 2).

m = size(A, 1);
n = size(A, 2);

% A*B*A.' is only defined for a square B matching the columns of A.
if ~isequal(size(B), [n, n])
    error('quadGradient:sizeMismatch', ...
          'B must be %d-by-%d to match size(A, 2).', n, n);
end

% Neither product term depends on i, so form each product once.
BAt = B * A.';   % n-by-m; row j equals B(j, :) * A.'
AB  = A * B;     % m-by-n; column j equals A * B(:, j)

% Each slice G(:, :, i, j) is m-by-m because A*B*A.' is m-by-m.
G = zeros(m, m, m, n);
for j = 1:n
    rowTerm = BAt(j, :);
    colTerm = AB(:, j);
    for i = 1:m
        % Accumulate rather than assign: entry (i, i) of the slice
        % receives a contribution from both terms.
        G(i, :, i, j) = G(i, :, i, j) + rowTerm;
        G(:, i, i, j) = G(:, i, i, j) + colTerm;
    end
end
